Skip to content

Conversation

@wsmoses
Copy link
Member

@wsmoses wsmoses commented Jan 4, 2025

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Jan 4, 2025

@llvm/pr-subscribers-mlir

Author: William Moses (wsmoses)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/121624.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Target/LLVMIR/ModuleImport.h (+4-1)
  • (modified) mlir/lib/Target/LLVMIR/ModuleImport.cpp (+59-46)
  • (modified) mlir/test/Target/LLVMIR/Import/import-failure.ll (-9)
  • (modified) mlir/test/Target/LLVMIR/Import/instructions.ll (+10)
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index eea0647895b01b..880e8201318e6c 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -319,9 +319,12 @@ class ModuleImport {
   /// Appends the converted result type and operands of `callInst` to the
   /// `types` and `operands` arrays. For indirect calls, the method additionally
   /// inserts the called function at the beginning of the `operands` array.
+  /// If `handleAsm` is set to false (the default), it will err if the handler
+  /// is an inline asm which isn't convertible to MLIR as a value.
   LogicalResult convertCallTypeAndOperands(llvm::CallBase *callInst,
                                            SmallVectorImpl<Type> &types,
-                                           SmallVectorImpl<Value> &operands);
+                                           SmallVectorImpl<Value> &operands,
+                                           bool handleAsm = false);
   /// Converts the parameter attributes attached to `func` and adds them to the
   /// `funcOp`.
   void convertParameterAttributes(llvm::Function *func, LLVMFuncOp funcOp,
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index b0d5e635248d3f..66cfb3f9dca110 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1473,18 +1473,19 @@ ModuleImport::convertBranchArgs(llvm::Instruction *branch,
   return success();
 }
 
-LogicalResult
-ModuleImport::convertCallTypeAndOperands(llvm::CallBase *callInst,
-                                         SmallVectorImpl<Type> &types,
-                                         SmallVectorImpl<Value> &operands) {
+LogicalResult ModuleImport::convertCallTypeAndOperands(
+    llvm::CallBase *callInst, SmallVectorImpl<Type> &types,
+    SmallVectorImpl<Value> &operands, bool handleAsm) {
   if (!callInst->getType()->isVoidTy())
     types.push_back(convertType(callInst->getType()));
 
   if (!callInst->getCalledFunction()) {
-    FailureOr<Value> called = convertValue(callInst->getCalledOperand());
-    if (failed(called))
-      return failure();
-    operands.push_back(*called);
+    if (!handleAsm || !isa<llvm::InlineAsm>(callInst->getCalledOperand())) {
+      FailureOr<Value> called = convertValue(callInst->getCalledOperand());
+      if (failed(called))
+        return failure();
+      operands.push_back(*called);
+    }
   }
   SmallVector<llvm::Value *> args(callInst->args());
   FailureOr<SmallVector<Value>> arguments = convertValues(args);
@@ -1579,7 +1580,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
 
     SmallVector<Type> types;
     SmallVector<Value> operands;
-    if (failed(convertCallTypeAndOperands(callInst, types, operands)))
+    if (failed(convertCallTypeAndOperands(callInst, types, operands, true)))
       return failure();
 
     auto funcTy =
@@ -1587,45 +1588,57 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
     if (!funcTy)
       return failure();
 
-    CallOp callOp;
-
-    if (llvm::Function *callee = callInst->getCalledFunction()) {
-      callOp = builder.create<CallOp>(
-          loc, funcTy, SymbolRefAttr::get(context, callee->getName()),
-          operands);
+    if (auto asmI = dyn_cast<llvm::InlineAsm>(callInst->getCalledOperand())) {
+      InlineAsmOp callOp = builder.create<InlineAsmOp>(
+          loc, funcTy.getReturnType(), operands,
+          builder.getStringAttr(asmI->getAsmString()),
+          builder.getStringAttr(asmI->getConstraintString()), nullptr, nullptr,
+          nullptr, nullptr);
+      if (!callInst->getType()->isVoidTy())
+        mapValue(inst, callOp.getResult(0));
+      else
+        mapNoResultOp(inst, callOp);
     } else {
-      callOp = builder.create<CallOp>(loc, funcTy, operands);
+      CallOp callOp;
+
+      if (llvm::Function *callee = callInst->getCalledFunction()) {
+        callOp = builder.create<CallOp>(
+            loc, funcTy, SymbolRefAttr::get(context, callee->getName()),
+            operands);
+      } else {
+        callOp = builder.create<CallOp>(loc, funcTy, operands);
+      }
+      callOp.setCConv(convertCConvFromLLVM(callInst->getCallingConv()));
+      callOp.setTailCallKind(
+          convertTailCallKindFromLLVM(callInst->getTailCallKind()));
+      setFastmathFlagsAttr(inst, callOp);
+
+      // Handle function attributes.
+      if (callInst->hasFnAttr(llvm::Attribute::Convergent))
+        callOp.setConvergent(true);
+      if (callInst->hasFnAttr(llvm::Attribute::NoUnwind))
+        callOp.setNoUnwind(true);
+      if (callInst->hasFnAttr(llvm::Attribute::WillReturn))
+        callOp.setWillReturn(true);
+
+      llvm::MemoryEffects memEffects = callInst->getMemoryEffects();
+      ModRefInfo othermem = convertModRefInfoFromLLVM(
+          memEffects.getModRef(llvm::MemoryEffects::Location::Other));
+      ModRefInfo argMem = convertModRefInfoFromLLVM(
+          memEffects.getModRef(llvm::MemoryEffects::Location::ArgMem));
+      ModRefInfo inaccessibleMem = convertModRefInfoFromLLVM(
+          memEffects.getModRef(llvm::MemoryEffects::Location::InaccessibleMem));
+      auto memAttr = MemoryEffectsAttr::get(callOp.getContext(), othermem,
+                                            argMem, inaccessibleMem);
+      // Only set the attribute when it does not match the default value.
+      if (!memAttr.isReadWrite())
+        callOp.setMemoryEffectsAttr(memAttr);
+
+      if (!callInst->getType()->isVoidTy())
+        mapValue(inst, callOp.getResult());
+      else
+        mapNoResultOp(inst, callOp);
     }
-    callOp.setCConv(convertCConvFromLLVM(callInst->getCallingConv()));
-    callOp.setTailCallKind(
-        convertTailCallKindFromLLVM(callInst->getTailCallKind()));
-    setFastmathFlagsAttr(inst, callOp);
-
-    // Handle function attributes.
-    if (callInst->hasFnAttr(llvm::Attribute::Convergent))
-      callOp.setConvergent(true);
-    if (callInst->hasFnAttr(llvm::Attribute::NoUnwind))
-      callOp.setNoUnwind(true);
-    if (callInst->hasFnAttr(llvm::Attribute::WillReturn))
-      callOp.setWillReturn(true);
-
-    llvm::MemoryEffects memEffects = callInst->getMemoryEffects();
-    ModRefInfo othermem = convertModRefInfoFromLLVM(
-        memEffects.getModRef(llvm::MemoryEffects::Location::Other));
-    ModRefInfo argMem = convertModRefInfoFromLLVM(
-        memEffects.getModRef(llvm::MemoryEffects::Location::ArgMem));
-    ModRefInfo inaccessibleMem = convertModRefInfoFromLLVM(
-        memEffects.getModRef(llvm::MemoryEffects::Location::InaccessibleMem));
-    auto memAttr = MemoryEffectsAttr::get(callOp.getContext(), othermem, argMem,
-                                          inaccessibleMem);
-    // Only set the attribute when it does not match the default value.
-    if (!memAttr.isReadWrite())
-      callOp.setMemoryEffectsAttr(memAttr);
-
-    if (!callInst->getType()->isVoidTy())
-      mapValue(inst, callOp.getResult());
-    else
-      mapNoResultOp(inst, callOp);
     return success();
   }
   if (inst->getOpcode() == llvm::Instruction::LandingPad) {
diff --git a/mlir/test/Target/LLVMIR/Import/import-failure.ll b/mlir/test/Target/LLVMIR/Import/import-failure.ll
index 6bde174642d540..b616cb81e0a8a5 100644
--- a/mlir/test/Target/LLVMIR/Import/import-failure.ll
+++ b/mlir/test/Target/LLVMIR/Import/import-failure.ll
@@ -12,15 +12,6 @@ bb2:
 
 ; // -----
 
-; CHECK:      <unknown>
-; CHECK-SAME: error: unhandled value: ptr asm "bswap $0", "=r,r"
-define i32 @unhandled_value(i32 %arg1) {
-  %1 = call i32 asm "bswap $0", "=r,r"(i32 %arg1)
-  ret i32 %1
-}
-
-; // -----
-
 ; CHECK:      <unknown>
 ; CHECK-SAME: unhandled constant: ptr blockaddress(@unhandled_constant, %bb1) since blockaddress(...) is unsupported
 ; CHECK:      <unknown>
diff --git a/mlir/test/Target/LLVMIR/Import/instructions.ll b/mlir/test/Target/LLVMIR/Import/instructions.ll
index fff48bbc486bc1..42b2cfff9611a3 100644
--- a/mlir/test/Target/LLVMIR/Import/instructions.ll
+++ b/mlir/test/Target/LLVMIR/Import/instructions.ll
@@ -535,6 +535,16 @@ define void @indirect_vararg_call(ptr addrspace(42) %fn) {
 
 ; // -----
 
+; CHECK-LABEL: @inlineasm
+; CHECK-SAME:  %[[ARG1:[a-zA-Z0-9]+]]
+define i32 @inlineasm(i32 %arg1) {
+  %1 = call i32 asm "bswap $0", "=r,r"(i32 %arg1)
+  ; CHECK:  llvm.inline_asm has_side_effects asm_dialect = att "bswap $0", "=r,r" (%[[ARG1]]) : (i32) -> (i32)
+  ret i32 %1
+}
+
+; // -----
+
 ; CHECK-LABEL: @gep_static_idx
 ; CHECK-SAME:  %[[PTR:[a-zA-Z0-9]+]]
 define void @gep_static_idx(ptr %ptr) {

@llvmbot
Copy link
Member

llvmbot commented Jan 4, 2025

@llvm/pr-subscribers-mlir-llvm

Author: William Moses (wsmoses)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/121624.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Target/LLVMIR/ModuleImport.h (+4-1)
  • (modified) mlir/lib/Target/LLVMIR/ModuleImport.cpp (+59-46)
  • (modified) mlir/test/Target/LLVMIR/Import/import-failure.ll (-9)
  • (modified) mlir/test/Target/LLVMIR/Import/instructions.ll (+10)
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index eea0647895b01b..880e8201318e6c 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -319,9 +319,12 @@ class ModuleImport {
   /// Appends the converted result type and operands of `callInst` to the
   /// `types` and `operands` arrays. For indirect calls, the method additionally
   /// inserts the called function at the beginning of the `operands` array.
+  /// If `handleAsm` is set to false (the default), it will err if the handler
+  /// is an inline asm which isn't convertible to MLIR as a value.
   LogicalResult convertCallTypeAndOperands(llvm::CallBase *callInst,
                                            SmallVectorImpl<Type> &types,
-                                           SmallVectorImpl<Value> &operands);
+                                           SmallVectorImpl<Value> &operands,
+                                           bool handleAsm = false);
   /// Converts the parameter attributes attached to `func` and adds them to the
   /// `funcOp`.
   void convertParameterAttributes(llvm::Function *func, LLVMFuncOp funcOp,
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index b0d5e635248d3f..66cfb3f9dca110 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1473,18 +1473,19 @@ ModuleImport::convertBranchArgs(llvm::Instruction *branch,
   return success();
 }
 
-LogicalResult
-ModuleImport::convertCallTypeAndOperands(llvm::CallBase *callInst,
-                                         SmallVectorImpl<Type> &types,
-                                         SmallVectorImpl<Value> &operands) {
+LogicalResult ModuleImport::convertCallTypeAndOperands(
+    llvm::CallBase *callInst, SmallVectorImpl<Type> &types,
+    SmallVectorImpl<Value> &operands, bool handleAsm) {
   if (!callInst->getType()->isVoidTy())
     types.push_back(convertType(callInst->getType()));
 
   if (!callInst->getCalledFunction()) {
-    FailureOr<Value> called = convertValue(callInst->getCalledOperand());
-    if (failed(called))
-      return failure();
-    operands.push_back(*called);
+    if (!handleAsm || !isa<llvm::InlineAsm>(callInst->getCalledOperand())) {
+      FailureOr<Value> called = convertValue(callInst->getCalledOperand());
+      if (failed(called))
+        return failure();
+      operands.push_back(*called);
+    }
   }
   SmallVector<llvm::Value *> args(callInst->args());
   FailureOr<SmallVector<Value>> arguments = convertValues(args);
@@ -1579,7 +1580,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
 
     SmallVector<Type> types;
     SmallVector<Value> operands;
-    if (failed(convertCallTypeAndOperands(callInst, types, operands)))
+    if (failed(convertCallTypeAndOperands(callInst, types, operands, true)))
       return failure();
 
     auto funcTy =
@@ -1587,45 +1588,57 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
     if (!funcTy)
       return failure();
 
-    CallOp callOp;
-
-    if (llvm::Function *callee = callInst->getCalledFunction()) {
-      callOp = builder.create<CallOp>(
-          loc, funcTy, SymbolRefAttr::get(context, callee->getName()),
-          operands);
+    if (auto asmI = dyn_cast<llvm::InlineAsm>(callInst->getCalledOperand())) {
+      InlineAsmOp callOp = builder.create<InlineAsmOp>(
+          loc, funcTy.getReturnType(), operands,
+          builder.getStringAttr(asmI->getAsmString()),
+          builder.getStringAttr(asmI->getConstraintString()), nullptr, nullptr,
+          nullptr, nullptr);
+      if (!callInst->getType()->isVoidTy())
+        mapValue(inst, callOp.getResult(0));
+      else
+        mapNoResultOp(inst, callOp);
     } else {
-      callOp = builder.create<CallOp>(loc, funcTy, operands);
+      CallOp callOp;
+
+      if (llvm::Function *callee = callInst->getCalledFunction()) {
+        callOp = builder.create<CallOp>(
+            loc, funcTy, SymbolRefAttr::get(context, callee->getName()),
+            operands);
+      } else {
+        callOp = builder.create<CallOp>(loc, funcTy, operands);
+      }
+      callOp.setCConv(convertCConvFromLLVM(callInst->getCallingConv()));
+      callOp.setTailCallKind(
+          convertTailCallKindFromLLVM(callInst->getTailCallKind()));
+      setFastmathFlagsAttr(inst, callOp);
+
+      // Handle function attributes.
+      if (callInst->hasFnAttr(llvm::Attribute::Convergent))
+        callOp.setConvergent(true);
+      if (callInst->hasFnAttr(llvm::Attribute::NoUnwind))
+        callOp.setNoUnwind(true);
+      if (callInst->hasFnAttr(llvm::Attribute::WillReturn))
+        callOp.setWillReturn(true);
+
+      llvm::MemoryEffects memEffects = callInst->getMemoryEffects();
+      ModRefInfo othermem = convertModRefInfoFromLLVM(
+          memEffects.getModRef(llvm::MemoryEffects::Location::Other));
+      ModRefInfo argMem = convertModRefInfoFromLLVM(
+          memEffects.getModRef(llvm::MemoryEffects::Location::ArgMem));
+      ModRefInfo inaccessibleMem = convertModRefInfoFromLLVM(
+          memEffects.getModRef(llvm::MemoryEffects::Location::InaccessibleMem));
+      auto memAttr = MemoryEffectsAttr::get(callOp.getContext(), othermem,
+                                            argMem, inaccessibleMem);
+      // Only set the attribute when it does not match the default value.
+      if (!memAttr.isReadWrite())
+        callOp.setMemoryEffectsAttr(memAttr);
+
+      if (!callInst->getType()->isVoidTy())
+        mapValue(inst, callOp.getResult());
+      else
+        mapNoResultOp(inst, callOp);
     }
-    callOp.setCConv(convertCConvFromLLVM(callInst->getCallingConv()));
-    callOp.setTailCallKind(
-        convertTailCallKindFromLLVM(callInst->getTailCallKind()));
-    setFastmathFlagsAttr(inst, callOp);
-
-    // Handle function attributes.
-    if (callInst->hasFnAttr(llvm::Attribute::Convergent))
-      callOp.setConvergent(true);
-    if (callInst->hasFnAttr(llvm::Attribute::NoUnwind))
-      callOp.setNoUnwind(true);
-    if (callInst->hasFnAttr(llvm::Attribute::WillReturn))
-      callOp.setWillReturn(true);
-
-    llvm::MemoryEffects memEffects = callInst->getMemoryEffects();
-    ModRefInfo othermem = convertModRefInfoFromLLVM(
-        memEffects.getModRef(llvm::MemoryEffects::Location::Other));
-    ModRefInfo argMem = convertModRefInfoFromLLVM(
-        memEffects.getModRef(llvm::MemoryEffects::Location::ArgMem));
-    ModRefInfo inaccessibleMem = convertModRefInfoFromLLVM(
-        memEffects.getModRef(llvm::MemoryEffects::Location::InaccessibleMem));
-    auto memAttr = MemoryEffectsAttr::get(callOp.getContext(), othermem, argMem,
-                                          inaccessibleMem);
-    // Only set the attribute when it does not match the default value.
-    if (!memAttr.isReadWrite())
-      callOp.setMemoryEffectsAttr(memAttr);
-
-    if (!callInst->getType()->isVoidTy())
-      mapValue(inst, callOp.getResult());
-    else
-      mapNoResultOp(inst, callOp);
     return success();
   }
   if (inst->getOpcode() == llvm::Instruction::LandingPad) {
diff --git a/mlir/test/Target/LLVMIR/Import/import-failure.ll b/mlir/test/Target/LLVMIR/Import/import-failure.ll
index 6bde174642d540..b616cb81e0a8a5 100644
--- a/mlir/test/Target/LLVMIR/Import/import-failure.ll
+++ b/mlir/test/Target/LLVMIR/Import/import-failure.ll
@@ -12,15 +12,6 @@ bb2:
 
 ; // -----
 
-; CHECK:      <unknown>
-; CHECK-SAME: error: unhandled value: ptr asm "bswap $0", "=r,r"
-define i32 @unhandled_value(i32 %arg1) {
-  %1 = call i32 asm "bswap $0", "=r,r"(i32 %arg1)
-  ret i32 %1
-}
-
-; // -----
-
 ; CHECK:      <unknown>
 ; CHECK-SAME: unhandled constant: ptr blockaddress(@unhandled_constant, %bb1) since blockaddress(...) is unsupported
 ; CHECK:      <unknown>
diff --git a/mlir/test/Target/LLVMIR/Import/instructions.ll b/mlir/test/Target/LLVMIR/Import/instructions.ll
index fff48bbc486bc1..42b2cfff9611a3 100644
--- a/mlir/test/Target/LLVMIR/Import/instructions.ll
+++ b/mlir/test/Target/LLVMIR/Import/instructions.ll
@@ -535,6 +535,16 @@ define void @indirect_vararg_call(ptr addrspace(42) %fn) {
 
 ; // -----
 
+; CHECK-LABEL: @inlineasm
+; CHECK-SAME:  %[[ARG1:[a-zA-Z0-9]+]]
+define i32 @inlineasm(i32 %arg1) {
+  %1 = call i32 asm "bswap $0", "=r,r"(i32 %arg1)
+  ; CHECK:  llvm.inline_asm has_side_effects asm_dialect = att "bswap $0", "=r,r" (%[[ARG1]]) : (i32) -> (i32)
+  ret i32 %1
+}
+
+; // -----
+
 ; CHECK-LABEL: @gep_static_idx
 ; CHECK-SAME:  %[[PTR:[a-zA-Z0-9]+]]
 define void @gep_static_idx(ptr %ptr) {

@wsmoses wsmoses force-pushed the mlirasm branch 2 times, most recently from 594bedc to 2b6ab6a Compare January 4, 2025 06:09
Copy link
Contributor

@gysit gysit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding inline asm support!

InlineAsmOp callOp = builder.create<InlineAsmOp>(
loc, funcTy.getReturnType(), operands,
builder.getStringAttr(asmI->getAsmString()),
builder.getStringAttr(asmI->getConstraintString()), nullptr, nullptr,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you prefix the nullptr arguments with the argument name. E.g. /*has_side_effects=*/false etc.

Shouldn't we extract these extra arguments from the assembler instruction? Or at least use conservative values such as true for has side effects?

; CHECK: <unknown>
; CHECK-SAME: error: unhandled value: ptr asm "bswap $0", "=r,r"
define i32 @unhandled_value(i32 %arg1) {
%1 = call i32 asm "bswap $0", "=r,r"(i32 %arg1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally we could keep a test that runs into an unhandled value error. Are you by chance aware of an alternative test case? All cases I can think of run into another error first (e.g. in the unhandled instruction error above).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not that I could think of. As an example, tried the callbr version (since I only added handling for call), but that first hit a bad conversion for the basicblock which was considered a constant not value

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok let's drop the test then.

@wsmoses
Copy link
Member Author

wsmoses commented Jan 4, 2025

okay addressed all your comments, @gysit!

Copy link
Contributor

@gysit gysit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for addressing the comments!

LGTM, modulo last comment.

Also the commit message should be reformatted to:

[MLIR] Enable importing inlineasm calls

to match the LLVM style.

; CHECK: <unknown>
; CHECK-SAME: error: unhandled value: ptr asm "bswap $0", "=r,r"
define i32 @unhandled_value(i32 %arg1) {
%1 = call i32 asm "bswap $0", "=r,r"(i32 %arg1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok let's drop the test then.

@wsmoses wsmoses merged commit b5f2167 into llvm:main Jan 5, 2025
8 checks passed
@wsmoses wsmoses deleted the mlirasm branch January 5, 2025 16:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants